import torch

def softmax_demo(shape=(3, 3, 3), dim=0, seed=None):
    """Apply ``torch.nn.Softmax`` to a random tensor and check its normalisation.

    Args:
        shape: Shape of the random input drawn from ``torch.randn``.
        dim: Dimension along which the softmax is computed (and summed).
        seed: Optional seed for a private ``torch.Generator``; ``None`` draws
            from the global RNG, as the original script did.

    Returns:
        A tuple ``(inps, out, sums)``: the random input, its softmax along
        ``dim``, and ``out`` summed along that same ``dim``. Every entry of
        ``sums`` should be approximately 1.0.
    """
    # A private generator keeps runs reproducible without touching global RNG state.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    inps = torch.randn(shape, generator=generator)
    out = torch.nn.Softmax(dim=dim)(inps)
    # Sum along the softmax dimension; this must match `dim` or the check is meaningless.
    sums = torch.sum(out, dim=dim)
    return inps, out, sums


if __name__ == "__main__":
    inps, out, sums = softmax_demo()
    print(inps)
    print(out)
    print(sums)
    


    